Add ConformalRiskCalibrator and ConformalRiskPredictor for conformal risk control prediction sets #8939

Open
txmed82 wants to merge 2 commits into Project-MONAI:dev from txmed82:conformal-risk

@txmed82 txmed82 commented Jun 21, 2026

Fixes #8935 (part 2 of 2).

Description

Adds image-level conformal risk control to monai/metrics/, the loss-bounded
counterpart to the marginal-coverage ConformalPredictor in #8938.

monai/metrics/conformal_risk.py:

  • ConformalRiskCalibrator: calibrate one global threshold lambda_hat that
    bounds an image-level loss (miscoverage or false_negative) on a held-out
    split via the finite-sample selection of Angelopoulos et al. 2022, giving
    E[L] <= alpha. Handles classification and variable-size segmentation;
    include_background=False drops background voxels.
  • ConformalRiskPredictor: apply lambda_hat and return the prediction-set
    mask plus a per-voxel uncertainty mask (set size > 1).
  • Coverage / SetSize: CumulativeIterationMetrics for the coverage vs.
    set-size trade-off.
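
As an illustration of the ConformalRiskPredictor bullet above, here is a minimal plain-Python sketch of how a threshold could be turned into prediction sets and an uncertainty mask. It assumes the non-conformity score of a class is 1 - p (one minus its softmax probability) and that a class enters the set when its score is <= lambda. `prediction_sets` is a hypothetical helper written for this note, not the PR's API, which operates on tensors.

```python
from typing import List, Tuple


def prediction_sets(probs: List[List[float]], lam: float) -> Tuple[List[List[bool]], List[bool]]:
    """Return per-voxel prediction sets and an uncertainty mask.

    Assumption (not taken from the PR): the non-conformity score of class y
    is 1 - p(y), and y is included in the set when that score is <= lam.
    """
    sets = [[(1.0 - p) <= lam for p in voxel] for voxel in probs]
    # A voxel is flagged as uncertain when more than one class is in its set.
    uncertain = [sum(s) > 1 for s in sets]
    return sets, uncertain


# Three voxels, three classes (softmax outputs).
probs = [[0.90, 0.05, 0.05], [0.50, 0.45, 0.05], [0.34, 0.33, 0.33]]
sets, uncertain = prediction_sets(probs, lam=0.6)
```

With lam=0.6 the first voxel gets a singleton set, the second gets two classes and is flagged as uncertain, and the third gets an empty set because no class reaches a probability of 0.4.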

API docs added to docs/source/metrics.rst.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

coderabbitai Bot commented Jun 21, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 0138352b-5934-467d-963a-d70dcbadd9c5

📥 Commits

Reviewing files that changed from the base of the PR and between 2a500d7 and 6b0fc27.

📒 Files selected for processing (4)
  • docs/source/metrics.rst
  • monai/metrics/__init__.py
  • monai/metrics/conformal_risk.py
  • tests/metrics/test_conformal_risk.py
✅ Files skipped from review due to trivial changes (1)
  • docs/source/metrics.rst
🚧 Files skipped from review as they are similar to previous changes (3)
  • monai/metrics/__init__.py
  • tests/metrics/test_conformal_risk.py
  • monai/metrics/conformal_risk.py

📝 Walkthrough

Adds monai/metrics/conformal_risk.py implementing Conformal Risk Control: miscoverage and false-negative loss functions, ConformalRiskCalibrator accumulating calibration data and selecting thresholds via finite-sample risk bounds, ConformalRiskPredictor converting softmax to boolean prediction sets and uncertainty masks. Includes standalone compute_coverage and compute_set_size evaluation functions and Coverage/SetSize CumulativeIterationMetric subclasses. Exports all symbols via monai/metrics/__init__.py and documents in docs/source/metrics.rst. Comprehensive test suite validates loss functions, calibrator scenarios, predictor contracts, metric aggregation, and error handling.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~70 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 22.92%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | Title clearly and specifically describes the main addition of ConformalRiskCalibrator and ConformalRiskPredictor for conformal risk control. |
| Description check | ✅ Passed | Description follows template with issue reference, clear implementation details, and accurate checklist items marked. |
| Linked Issues check | ✅ Passed | PR fully implements issue #8935 objectives: threshold calibration with loss bounds, per-voxel uncertainty masks, image-level loss control, and metrics for coverage/set-size trade-off. |
| Out of Scope Changes check | ✅ Passed | All changes scope to conformal risk control: new metrics module components, documentation, tests, and package-level exports are directly aligned with objectives. |


@txmed82 txmed82 force-pushed the conformal-risk branch 2 times, most recently from 85b8a37 to 2fd9a20 Compare June 21, 2026 15:43

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
monai/metrics/conformal_risk.py (4)

253-253: ⚡ Quick win

Add strict=True to zip().

Ensures _scores and _labels have matching lengths; helps catch logic errors early.

-        for scores_i, labels_i in zip(self._scores, self._labels):
+        for scores_i, labels_i in zip(self._scores, self._labels, strict=True):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/metrics/conformal_risk.py` at line 253, Add the `strict=True` parameter
to the zip() call that iterates over self._scores and self._labels in the for
loop. This will ensure that both iterables have matching lengths and raise a
ValueError if they don't, helping catch logic errors early.

Source: Linters/SAST tools


50-50: ⚡ Quick win

Dead import: tqdm is never used.

Neither tqdm nor has_tqdm appears elsewhere in the file.

🧹 Remove unused import
-tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/metrics/conformal_risk.py` at line 50, The optional_import call for
tqdm on line 50 imports both tqdm and has_tqdm variables, but neither of these
variables is used anywhere else in the conformal_risk.py file. Remove the entire
line containing the tqdm optional_import statement since it introduces dead code
that serves no purpose.

277-279: ⚡ Quick win

Missing docstring for reset().

Per coding guidelines, all definitions should have docstrings.

     def reset(self) -> None:
+        """Clear accumulated calibration data and reset internal state."""
         self._scores, self._labels = [], []
         self._num_classes = None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/metrics/conformal_risk.py` around lines 277 - 279, The reset() method
is missing a docstring as required by coding guidelines. Add a docstring to the
reset() method that describes its purpose, which is to clear the internal state
by resetting the scores list, labels list, and num_classes attribute back to
their initial values. Follow the project's docstring format conventions.

Source: Coding guidelines


76-76: ⚖️ Poor tradeoff

Silent clamping may mask upstream label bugs.

Out-of-range labels are silently clamped to [0, C-1]. Consider logging a warning when clamping occurs, so upstream data issues surface during debugging.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/metrics/conformal_risk.py` at line 76, The labels_flat.clamp operation
silently adjusts out-of-range labels without any logging, which can hide
upstream data issues. Before applying the clamp operation on labels_flat, detect
if any values fall outside the valid range [0, c-1] by checking the minimum and
maximum values. If out-of-range values are detected, log a warning message that
indicates clamping is occurring and optionally includes details about how many
or what proportion of labels were affected. Then proceed with the clamp
operation as currently written.
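
A rough sketch of the warn-then-clamp behaviour described above, written in plain Python so it runs without torch. `clamp_with_warning` and `count_out_of_range` are hypothetical names introduced here; the real code works on a flattened label tensor.

```python
import warnings
from typing import List, Sequence


def count_out_of_range(labels: Sequence[int], num_classes: int) -> int:
    """Count labels that fall outside [0, num_classes - 1]."""
    return sum(1 for y in labels if y < 0 or y >= num_classes)


def clamp_with_warning(labels: Sequence[int], num_classes: int) -> List[int]:
    """Clamp labels to the valid range, warning when anything was changed."""
    bad = count_out_of_range(labels, num_classes)
    if bad:
        warnings.warn(
            f"{bad}/{len(labels)} labels outside [0, {num_classes - 1}] were clamped",
            stacklevel=2,
        )
    return [min(max(y, 0), num_classes - 1) for y in labels]


# Record warnings so the example can inspect them.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    clamped = clamp_with_warning([0, 2, 5, -1], num_classes=3)
```

Here two of the four labels are out of range, so one warning is emitted and the labels become `[0, 2, 2, 0]`. Raising an error instead of warning would be the stricter variant of the same check.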
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/metrics/conformal_risk.py`:
- Around line 59-65: The function `_set_from_threshold` in the file is dead code
that is never called, not exported, and has no tests associated with it. Remove
the entire function definition including the docstring to clean up the codebase
and reduce maintenance burden.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: e3510c75-6712-4d38-bb82-842dd7c3fb24

📥 Commits

Reviewing files that changed from the base of the PR and between 15f5073 and 85b8a37.

📒 Files selected for processing (4)
  • docs/source/metrics.rst
  • monai/metrics/__init__.py
  • monai/metrics/conformal_risk.py
  • tests/metrics/test_conformal_risk.py

Comment on lines +59 to +65
def _set_from_threshold(scores: torch.Tensor, lam: float) -> torch.Tensor:
    """Boolean prediction set ``{ y : score(y) <= lam }``, shape (..., C).

    ``scores`` is the non-conformity score tensor with class as the last dim;
    ``lam`` is a scalar threshold. Returns a bool tensor of the same shape.
    """
    return scores <= lam


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if _set_from_threshold is used anywhere in the codebase
rg -n '_set_from_threshold' --type py

Repository: Project-MONAI/MONAI

Length of output: 173


🏁 Script executed:

cat -n monai/metrics/conformal_risk.py | head -80 | tail -25

Repository: Project-MONAI/MONAI

Length of output: 1375


🏁 Script executed:

rg '_set_from_threshold' --type py -g '**test**'

Repository: Project-MONAI/MONAI

Length of output: 45


🏁 Script executed:

rg '__all__' monai/metrics/conformal_risk.py -A 20

Repository: Project-MONAI/MONAI

Length of output: 806


Remove unused _set_from_threshold function.

The function is never called, not exported, and untested. Dead code should be removed.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/metrics/conformal_risk.py` around lines 59 - 65, The function
`_set_from_threshold` in the file is dead code that is never called, not
exported, and has no tests associated with it. Remove the entire function
definition including the docstring to clean up the codebase and reduce
maintenance burden.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 7

🧹 Nitpick comments (2)
monai/metrics/conformal_risk.py (1)

277-279: ⚡ Quick win

Add a Google-style docstring for reset.

This public method is missing a docstring.

Proposed fix
     def reset(self) -> None:
+        """Clear accumulated calibration data.
+
+        Returns:
+            None.
+        """
         self._scores, self._labels = [], []
         self._num_classes = None

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/metrics/conformal_risk.py` around lines 277 - 279, The reset method in
the ConformalRisk class is missing a Google-style docstring. Add a docstring to
the reset method that describes its purpose of resetting internal state
variables (_scores, _labels, and _num_classes to their initial values). Include
a brief one-line summary followed by a more detailed description if needed.
Since the method takes no arguments and returns None, focus on describing what
state is being reset and why this operation is performed.

Source: Coding guidelines

tests/metrics/test_conformal_risk.py (1)

187-187: ⚡ Quick win

Assert the returned probs contract.

probs_out is unpacked but unused. Assert it instead of leaving Ruff RUF059.

Proposed fix
         sets, mask, probs_out = predictor(probs)
+        assert_allclose(probs_out, probs, atol=0)
         self.assertEqual(sets.shape, (1, 3, 3))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/metrics/test_conformal_risk.py` at line 187, The variable `probs_out`
is unpacked from the predictor call but remains unused, which triggers the
RUF059 linting rule. Replace the unused variable with an assertion that
validates the returned `probs_out` meets the expected contract, such as checking
its shape, type, or value ranges that align with the test's expectations. This
both eliminates the unused variable warning and improves test coverage by
verifying the predictor's output contract.

Source: Linters/SAST tools

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/metrics/conformal_risk.py`:
- Around line 75-76: Instead of silently clamping invalid labels to valid class
indices using the clamp operation, add validation to reject or raise an error
when labels contain invalid values outside the expected range of 0 to c-1. This
needs to be fixed in two locations: the labels_flat.clamp call around line 76
and the similar operation around line 226. Replace the clamping logic with a
check that validates all labels fall within the valid range and either raises an
exception or skips invalid samples, ensuring that corrupted labels do not
silently propagate through the loss, coverage, and lambda_hat calculations.
- Around line 259-262: The code materializes the entire lambda grid evaluation
at once in the line with `sets = scores_i.unsqueeze(0) <= lam_grid.view(-1, 1,
1)`, which creates a tensor of shape (n_lam, P_i, C) that can cause
out-of-memory errors for large 3D volumes. Refactor this block to chunk the
lambda grid into smaller batches and iterate through them in a loop, processing
each chunk separately and accumulating the risk contribution from each chunk
into risk_sum. This prevents materializing the full tensor at once while still
computing the correct cumulative risk.
- Around line 41-48: The `__all__` list in the conformal_risk.py file has items
that are not in alphabetical order, which violates Ruff's sorting requirements.
Specifically, swap the positions of "compute_set_size" and "compute_coverage" in
the `__all__` list so that "compute_coverage" appears before "compute_set_size",
matching alphabetical order.
- Around line 193-197: The validation of lam_grid is incomplete and allows two
problematic cases: empty grids and unsorted grids. In the validation condition
that checks if lam_grid is a 1-D tensor with values in [0, 1], add two
additional checks: verify that lam_grid is not empty (check the length or size),
and verify that lam_grid is sorted in ascending order (for example, check that
the differences between consecutive elements, as returned by torch.diff, are
non-negative). Update the error message in the ValueError to reflect all
validation requirements. This will prevent crashes at line 271 and ensure
correct behavior at line 273 when finding the infimum.
- Line 262: After calling self.loss_fn with sets_shaped and labels_rep on the
line where risk_sum is updated, add validation to ensure the output is valid.
Validate that the returned loss tensor has the expected shape matching
labels_rep, contains no NaN values, and all values are within the valid range of
0 to 1. If any validation fails, raise a descriptive error to prevent silent
failures in the CRC bound calculation. This validation should occur immediately
after the loss_fn call and before accumulating the result into risk_sum.
- Around line 319-323: The set_threshold method currently accepts tensors of
arbitrary shape, but it should enforce that the lam parameter is a scalar tensor
with a value in the range [0, 1]. After the existing isinstance check for
torch.Tensor, add validation to ensure the tensor has a scalar shape (empty
dimensions) and that its value falls within [0, 1], raising a ValueError if
either condition is violated. This prevents invalid broadcasts over spatial
dimensions that could produce incorrect conformal sets.
- Line 253: In the for loop iterating over self._scores and self._labels using
zip(), add the parameter strict=True to the zip() call to enforce that both
sequences have the same length. This will catch any synchronization issues
between these two buffers if they become out of sync in the future due to
unequal append operations.
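
The chunking suggestion for lines 259-262 can be sketched as follows. This is a plain-Python stand-in, assuming the image-level miscoverage loss is the fraction of voxels whose true class falls outside the set {y : score(y) <= lam}. The names `image_miscoverage` and `accumulate_risk` and the `chunk_size` parameter are hypothetical. In the tensor version, the point is that only `chunk_size` thresholds are materialized at a time instead of the full (n_lam, P_i, C) tensor.

```python
from typing import List, Sequence


def image_miscoverage(scores: Sequence[Sequence[float]], labels: Sequence[int], lam: float) -> float:
    """Fraction of voxels whose true class is left out of {y : score(y) <= lam}."""
    missed = sum(1 for s, y in zip(scores, labels) if s[y] > lam)
    return missed / len(labels)


def accumulate_risk(
    scores: Sequence[Sequence[float]],
    labels: Sequence[int],
    lam_grid: Sequence[float],
    chunk_size: int = 2,
) -> List[float]:
    """Evaluate the loss over the lambda grid one chunk at a time."""
    risk: List[float] = []
    for start in range(0, len(lam_grid), chunk_size):
        chunk = lam_grid[start:start + chunk_size]
        # Only the thresholds in `chunk` are evaluated in this iteration.
        risk.extend(image_miscoverage(scores, labels, lam) for lam in chunk)
    return risk


# One image with three voxels and two classes (non-conformity scores).
scores = [[0.1, 0.9], [0.7, 0.3], [0.4, 0.6]]
labels = [0, 0, 1]
lam_grid = [0.0, 0.25, 0.5, 0.75, 1.0]

risk_chunked = accumulate_risk(scores, labels, lam_grid, chunk_size=2)
risk_full = accumulate_risk(scores, labels, lam_grid, chunk_size=len(lam_grid))
```

Chunked and unchunked evaluation give the same per-threshold losses, so the chunk size only trades peak memory against the number of loop iterations.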

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7ae55399-50d4-479b-8bf4-a56352f5e0cf

📥 Commits

Reviewing files that changed from the base of the PR and between 85b8a37 and 2fd9a20.

📒 Files selected for processing (4)
  • docs/source/metrics.rst
  • monai/metrics/__init__.py
  • monai/metrics/conformal_risk.py
  • tests/metrics/test_conformal_risk.py
✅ Files skipped from review due to trivial changes (2)
  • monai/metrics/__init__.py
  • docs/source/metrics.rst

Comment thread monai/metrics/conformal_risk.py
Comment thread monai/metrics/conformal_risk.py Outdated
Comment thread monai/metrics/conformal_risk.py
Comment thread monai/metrics/conformal_risk.py Outdated
Comment thread monai/metrics/conformal_risk.py Outdated
Comment thread monai/metrics/conformal_risk.py Outdated
Comment thread monai/metrics/conformal_risk.py
txmed82 added 2 commits June 22, 2026 09:28
…risk control prediction sets

Signed-off-by: Colin Son <txmed82@users.noreply.github.com>
… lambda loop; sort __all__; scalar/range check set_threshold; zip(strict=True); docstrings

- conformal_risk.py: reject out-of-range labels instead of silent clamp (lines ~76, ~226)
- conformal_risk.py: chunk the lambda grid in calibrate() to avoid materializing (n_lam, P_i, C) at once
- conformal_risk.py: validate lam_grid is non-empty and sorted ascending (prevents IndexError and wrong infimum)
- conformal_risk.py: validate loss_fn output shape and NaN after each call
- conformal_risk.py: enforce set_threshold lam is scalar in [0, 1]
- conformal_risk.py: zip(strict=True) over _scores/_labels
- conformal_risk.py: alphabetical __all__ (RUF022), reset() docstring
- test_conformal_risk.py: assert predictor returns input probs unchanged (RUF059)

Signed-off-by: Colin Son <txmed82@users.noreply.github.com>
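
The lam_grid and set_threshold checks listed in the commit message could look roughly like this. It is a plain-Python sketch under the assumption that thresholds live in [0, 1]. `validate_lam_grid` and `validate_threshold` are hypothetical helpers, and the PR itself validates torch tensors rather than Python sequences.

```python
from typing import Sequence


def validate_lam_grid(lam_grid: Sequence[float]) -> None:
    """Reject empty, out-of-range, or unsorted threshold grids."""
    if len(lam_grid) == 0:
        raise ValueError("lam_grid must be non-empty")
    if any(not (0.0 <= lam <= 1.0) for lam in lam_grid):
        raise ValueError("lam_grid values must lie in [0, 1]")
    # Ascending order is needed so the first feasible entry is the infimum.
    if any(b < a for a, b in zip(lam_grid, lam_grid[1:])):
        raise ValueError("lam_grid must be sorted in ascending order")


def validate_threshold(lam: object) -> float:
    """Accept only a real scalar in [0, 1] as a manual threshold."""
    if isinstance(lam, bool) or not isinstance(lam, (int, float)):
        raise ValueError("lam must be a scalar number")
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return float(lam)


def raises_value_error(fn, *args) -> bool:
    """Small helper for the usage example below."""
    try:
        fn(*args)
    except ValueError:
        return True
    return False


validate_lam_grid([0.0, 0.5, 1.0])  # a valid grid passes silently
rejected_empty = raises_value_error(validate_lam_grid, [])
rejected_unsorted = raises_value_error(validate_lam_grid, [0.5, 0.25])
rejected_vector = raises_value_error(validate_threshold, [0.2, 0.4])
accepted = validate_threshold(0.3)
```

Rejecting a non-scalar threshold is what prevents the accidental broadcast over spatial dimensions mentioned in the review.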
@Whatsonyourmind

Reviewed this alongside #8938 — the risk-control math is correct, including the parts that are easy to get subtly wrong:

  • the finite-sample selection (n*R_hat(lambda) + B)/(n+1) <= alpha is implemented exactly as R_hat <= ((n+1)*alpha - B)/n;
  • within[0] is the right infimum, given the empirical risk is non-increasing in lambda;
  • the grid discretization is conservative — the selected lambda is >= the true continuous infimum, and since risk is non-increasing a larger lambda only lowers risk, so realized risk stays <= alpha;
  • the alpha < 1/(n+1) -> full-set fallback is right.
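
For readers following the points above, here is a minimal sketch of the selection rule in plain Python. It assumes the empirical risk has already been averaged over n calibration images on an ascending grid, and that the full-set fallback corresponds to lambda = 1.0. `select_lambda` and its parameter names are illustrative, not the PR's API.

```python
from typing import Sequence


def select_lambda(
    emp_risk: Sequence[float],
    lam_grid: Sequence[float],
    n: int,
    alpha: float,
    bound: float = 1.0,
) -> float:
    """Smallest grid lambda with (n * R_hat + B) / (n + 1) <= alpha.

    Rearranged, the condition is R_hat <= ((n + 1) * alpha - B) / n.
    Assumes lam_grid is ascending and emp_risk is non-increasing in lambda.
    """
    threshold = ((n + 1) * alpha - bound) / n
    within = [lam for lam, r in zip(lam_grid, emp_risk) if r <= threshold]
    if not within:
        # No feasible threshold (for B = 1 this always happens when
        # alpha < 1 / (n + 1)): fall back to the full set.
        return 1.0
    return within[0]


lam_grid = [0.0, 0.25, 0.5, 0.75, 1.0]
emp_risk = [1.0, 0.5, 0.2, 0.1, 0.0]

lam_hat = select_lambda(emp_risk, lam_grid, n=9, alpha=0.2)
lam_fallback = select_lambda(emp_risk, lam_grid, n=9, alpha=0.05)
```

With n = 9 and alpha = 0.2 the rearranged bound is 1/9, so the first grid point whose empirical risk is at most 1/9 is 0.75. With alpha = 0.05 the bound is negative and the fallback applies.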

Controlling an image-level loss (rather than per-voxel) is also the principled choice here: image exchangeability is exactly what the CRC theorem needs, so within-image voxel correlation doesn't break the guarantee (unlike a per-voxel marginal treatment).

One robustness item and minor nits:

1. A custom non-monotone loss silently breaks the guarantee. CRC (Angelopoulos et al. 2022, Thm 1) requires the loss to be non-increasing in lambda; both built-ins satisfy that, but the loss: Callable escape hatch lets a user pass one that doesn't — and then within[0] is no longer an infimum and E[L] <= alpha no longer holds, with no error raised. Two cheap options: (a) state the "must be non-increasing in lambda" precondition prominently on the loss arg, and/or (b) after building emp_risk over the grid, assert/warn if it isn't non-increasing ((emp_risk[1:] <= emp_risk[:-1] + tol).all()), which catches a bad custom loss essentially for free.

2. Nits.

  • B = 1 is the correct/tight bound for both built-ins (their range is exactly [0, 1]); only worth exposing as a parameter if you later add a loss with a smaller range.
  • The all-background continue (the image still counts in n but contributes 0 risk) is a reasonable definitional choice — worth a one-line comment so it isn't read as a missed accumulation.
  • Same "# ponytail:" placeholder comments as in #8938 (Add ConformalPredictor and ConformalCalibrator for split-conformal (LAC) prediction sets).
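
The monotonicity check suggested in item 1(b) can be sketched in plain Python as below. It mirrors the expression `(emp_risk[1:] <= emp_risk[:-1] + tol).all()` from the review. The function names and the choice to warn rather than raise are assumptions made for this illustration.

```python
import warnings
from typing import Sequence


def is_non_increasing(emp_risk: Sequence[float], tol: float = 1e-8) -> bool:
    """True when each risk value is at most its predecessor (within tol)."""
    return all(b <= a + tol for a, b in zip(emp_risk, emp_risk[1:]))


def check_monotone(emp_risk: Sequence[float], tol: float = 1e-8) -> bool:
    """Warn if the empirical risk over the lambda grid is not non-increasing."""
    ok = is_non_increasing(emp_risk, tol)
    if not ok:
        warnings.warn(
            "empirical risk is not non-increasing in lambda; "
            "the risk-control guarantee may not hold for this loss",
            stacklevel=2,
        )
    return ok


monotone_ok = is_non_increasing([1.0, 0.5, 0.5, 0.0])
monotone_bad = is_non_increasing([0.4, 0.1, 0.3])
```

The check costs one pass over the grid, which is why the review describes it as catching a bad custom loss essentially for free.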

Solid PR — glad to see CRC going into MONAI.

Development

Successfully merging this pull request may close these issues.

Feat Request: Conformal prediction

2 participants